// Copyright 2006-2020 Glenn McIntosh
// licensed under the GNU General Public Licence version 3
#pragma once

/** @file kalman.h */

// include files
#include "linear.h"
#include "matrix.h"
#include <cmath>

namespace math
{
/** real type */
using real = double;

/** Kalman filter. */
template<int N, int M> class KalmanUnscented
{
	static constexpr real alpha{0.001F}, kappa{0.F}, beta{2.F};
	static constexpr real lambda{alpha*alpha*(N+kappa)-N}, gamma{sqrt(N+lambda)};
public:
	/** update the filter
	@param zM set of measurements
	@param state state update function
	@param measurement measurement update function
	@param post post calculation update function
	@param shock values added to state covariance
	@param noise values added to input covariance
	*/
	template<typename State, typename Measurement, typename Post> void update(Array<M> zM, State state, Measurement measurement, Post post, const Array<N> &shock, const Array<M> &noise);
private: // constants
	static constexpr real weightM0{lambda/(N+lambda)}, weightM1{1/(N+lambda)/2};
    static constexpr real weightC0{lambda/(N+lambda)+(1-alpha*alpha+beta)}, weightC1{1/(N+lambda)/2};
private: // state and variance
	Array<N> x{{}};
	Matrix<N, N> xVar{{}};
};

/** Run one predict/correct cycle of the unscented Kalman filter.
Predict: sigma points drawn around the current estimate are passed through state(). Their weighted
mean and covariance, plus shock on the covariance diagonal, become the predicted estimate.
Correct: sigma points are redrawn from the predicted estimate and passed through measurement().
The weighted measurement mean and covariance (plus noise on the diagonal), together with the
state/measurement cross covariance, give the Kalman gain that blends zM into the estimate.
Finally post() is applied to the corrected state.
*/
template<int N, int M>
template<typename State, typename Measurement, typename Post>
void KalmanUnscented<N, M>::update(Array<M> zM, State state, Measurement measurement, Post post, const Array<N> &shock, const Array<M> &noise)
{
	// sigma points: [0] is the mean, [1..N] and [N+1..2N] are the mean plus/minus gamma times each factor column
	Array<N> sx[1+N*2];

	// (re)generate the sigma points around x from the covariance xVar (unscented transform);
	// used by both the predict and the correct step.
	// NOTE(review): assumes choleskyDecompose overwrites its argument with a lower-triangular factor L
	// (L * transpose(L) == xVar), so column i of L is a spread direction -- confirm in linear.h
	const auto drawSigmaPoints = [&]()
	{
		Matrix<N, N> xSD = xVar;
		choleskyDecompose(xSD);
		sx[0] = x;
		for (size_t j = 0; j < N; ++j)
			for (size_t i = 0; i < N; ++i)
			{
				sx[1+i][j] = x[j] + gamma*xSD[j][i];
				sx[1+N+i][j] = x[j] - gamma*xSD[j][i];
			}
	};

	// predict: propagate the sigma points through the state transition function
	drawSigmaPoints();
	for (size_t i = 0; i < 1+N*2; ++i)
		state(sx[i]);

	// predicted state mean
	for (size_t j = 0; j < N; ++j)
	{
		real sum = 0;
		for (size_t i = 1; i < 1+N*2; ++i)
			sum += sx[i][j];
		x[j] = weightM0*sx[0][j] + weightM1*sum;
	}

	// predicted state covariance: the lower triangle is computed and mirrored so the result is exactly symmetric
	for (size_t j = 0; j < N; ++j)
		for (size_t k = 0; k <= j; ++k)
		{
			real sum = 0;
			for (size_t i = 1; i < 1+N*2; ++i)
				sum += (sx[i][j]-x[j])*(sx[i][k]-x[k]);
			xVar[j][k] = xVar[k][j] = weightC0*(sx[0][j]-x[j])*(sx[0][k]-x[k]) + weightC1*sum;
		}

	// add process shock to the diagonal
	for (size_t j = 0; j < N; ++j)
		xVar[j][j] += shock[j];

	// correct: redraw sigma points from the predicted estimate and map them into measurement space
	drawSigmaPoints();
	Array<M> sz[1+N*2];
	for (size_t i = 0; i < 1+N*2; ++i)
		measurement(sz[i], sx[i]);

	// predicted measurement mean
	Array<M> z;
	for (size_t j = 0; j < M; ++j)
	{
		real sum = 0;
		for (size_t i = 1; i < 1+N*2; ++i)
			sum += sz[i][j];
		z[j] = weightM0*sz[0][j] + weightM1*sum;
	}

	// predicted measurement covariance (lower triangle mirrored, as above)
	Matrix<M, M> zVar;
	for (size_t j = 0; j < M; ++j)
		for (size_t k = 0; k <= j; ++k)
		{
			real sum = 0;
			for (size_t i = 1; i < 1+N*2; ++i)
				sum += (sz[i][j]-z[j])*(sz[i][k]-z[k]);
			zVar[j][k] = zVar[k][j] = weightC0*(sz[0][j]-z[j])*(sz[0][k]-z[k]) + weightC1*sum;
		}

	// add measurement noise to the diagonal
	for (size_t j = 0; j < M; ++j)
		zVar[j][j] += noise[j];

	// state/measurement cross covariance
	Matrix<N, M> xzVar;
	for (size_t j = 0; j < N; ++j)
		for (size_t k = 0; k < M; ++k)
		{
			real sum = 0;
			for (size_t i = 1; i < 1+N*2; ++i)
				sum += (sx[i][j]-x[j])*(sz[i][k]-z[k]);
			xzVar[j][k] = weightC0*(sx[0][j]-x[j])*(sz[0][k]-z[k]) + weightC1*sum;
		}

	// Kalman gain K = Pxz * inverse(Pzz)
	Matrix<N, M> gain;
	gain = xzVar * invert(zVar);

	// corrected state mean
	x = x + gain * (zM - z);

	// corrected state covariance: P - K*Pzz*K^T, written equivalently as P - Pxz*K^T
	xVar = xVar - xzVar * transpose(gain);

	// apply the caller's post-update hook to the corrected state
	post(x);
}
}
